from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import pdb
import oneflow
import argparse

# Command-line options: one checkpoint path per teacher, followed by the
# optimisation hyper-parameters. Registered from a table so every flag sits
# on a single line; the loop temporaries are removed afterwards so the module
# namespace only gains `parser` and `args`.
parser = argparse.ArgumentParser()
for _flag, _options in (
    ('--car_ckpt', dict(default="./ckpt/car_res50_model")),
    ('--dog_ckpt', dict(default="./ckpt/dog_res50_model")),
    ('--aircraft_ckpt', dict(default="./ckpt/aircraft_res50_model")),
    ('--flower_ckpt', dict(default="./ckpt/flower_res50_model")),
    ('--lr', dict(type=float, default=1e-3)),
    ('--batch_size', dict(type=int, default=16)),
):
    parser.add_argument(_flag, **_options)
del _flag, _options
args = parser.parse_args()


def _task_specs():
    """Describe the four fine-grained tasks being amalgamated.

    Returns:
        A tuple of ``(name, num_classes, dataset_cls, root, train_split,
        val_split)`` tuples. The order matters: it fixes the teacher order
        handed to the amalgamator and each task's column range inside the
        student's concatenated output.
    """
    datasets = vision.datasets
    return (
        ('car', 196, datasets.StanfordCars, './DataSets/StanfordCars/', 'train', 'test'),
        ('dog', 120, datasets.StanfordDogs, './DataSets/StanfordDogs/', 'train', 'test'),
        # NOTE(review): FGVC-Aircraft is commonly described as 100 variants;
        # 102 is kept unchanged so the head matches the existing aircraft
        # checkpoint -- confirm.
        ('aircraft', 102, datasets.FGVCAircraft, './DataSets/FGVCAircraft/', 'trainval', 'test'),
        ('flower', 102, datasets.Flowers102, './DataSets/Flower102/', 'train', 'valid'),
    )


def _build_transforms():
    """Return ``(train_transform, val_transform)``, both ending in ImageNet mean/std normalisation."""
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return train_transform, val_transform


def _logit_slice(start, end):
    """Build an ``attach_to`` adapter that keeps only student output columns ``[start, end)``.

    A factory (instead of a lambda created inside a loop) binds ``start`` and
    ``end`` per task, avoiding Python's late-binding closure pitfall.
    """
    def attach_to(outputs, targets):
        return outputs[:, start:end], targets
    return attach_to


def main():
    """Amalgamate four ResNet-50 teachers (cars, dogs, aircraft, flowers) into one ResNet-50 student.

    The student's output has ``sum(num_classes)`` columns: the teachers' label
    spaces concatenated in ``_task_specs()`` order. Each task is evaluated on
    its own column slice after every epoch. Checkpoint paths, learning rate and
    batch size are read from the module-level ``args``.
    """
    num_epochs = 100
    tasks = _task_specs()
    names = [spec[0] for spec in tasks]
    class_counts = [spec[1] for spec in tasks]

    # Datasets, one train/val pair per task.
    train_sets, val_sets = [], []
    for _, _, dataset_cls, root, train_split, val_split in tasks:
        train_sets.append(dataset_cls(root, split=train_split))
        val_sets.append(dataset_cls(root, split=val_split))

    # Teachers, then the student, whose head covers every task's classes.
    teachers = [
        vision.models.classification.resnet50(num_classes=n, pretrained=False)
        for n in class_counts
    ]
    student = vision.models.classification.resnet50(num_classes=sum(class_counts), pretrained=False)

    # Teacher weights: each path comes from the `--<name>_ckpt` CLI flag.
    for name, teacher in zip(names, teachers):
        teacher.load_state_dict(oneflow.load(getattr(args, name + '_ckpt')))

    train_transform, val_transform = _build_transforms()
    for dst in train_sets:
        dst.transform = train_transform
    for dst in val_sets:
        dst.transform = val_transform

    # Column boundaries of each task in the student output, derived from the
    # class counts so they cannot drift from the model heads.
    bounds = [0]
    for n in class_counts:
        bounds.append(bounds[-1] + n)

    train_dst = oneflow.utils.data.ConcatDataset(train_sets)
    train_loader = oneflow.utils.data.DataLoader(
        train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0)

    # Per-task evaluators. No validation loader drops its final partial batch,
    # so every accuracy is measured over the complete split.
    evaluators = []
    for name, val_dst, start, end in zip(names, val_sets, bounds[:-1], bounds[1:]):
        metric = metrics.MetricCompose(metric_dict={
            name + '_acc': metrics.Accuracy(attach_to=_logit_slice(start, end)),
        })
        loader = oneflow.utils.data.DataLoader(
            val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
        evaluators.append(engine.evaluator.BasicEvaluator(loader, metric))

    total_iters = len(train_loader) * num_epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    # Stepped once per iteration (see the AFTER_STEP scheduler callback below),
    # so the cosine schedule spans the whole run.
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_iters)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
    )

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=evaluator,
                                  metric_name=name + '_acc', ckpt_prefix='tfl_' + name)
            for name, evaluator in zip(names, evaluators)
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Pair every Conv2d of the student with the module at the same position in
    # each teacher. All five networks are resnet50 differing only in
    # num_classes, so their modules() sequences line up.
    layer_groups, layer_channels = [], []
    for blocks in zip(student.modules(), *(teacher.modules() for teacher in teachers)):
        if isinstance(blocks[0], oneflow.nn.Conv2d):
            layer_groups.append(list(blocks))
            layer_channels.append([block.out_channels for block in blocks])

    trainer.setup(student=student,
                  teachers=teachers,
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  # presumably one weight per logged loss (kd, amal, recons) -- confirm
                  weights=[1., 1., 1.])
    trainer.run(start_iter=0, max_iter=total_iters)

if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()